from stable_baselines3 import PPO
import os
from customlensnakeenv_single import SnakeEnv
import time
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import DummyVecEnv

if __name__ == "__main__":
    # Train a PPO agent on SnakeEnv in an endless loop, writing a checkpoint
    # after every TIMESTEPS-step learning round. Stop the script manually.

    # Read the clock once so the model and log directories always share the
    # same run id, even if a second boundary passes between the two lines.
    run_id = int(time.time())
    models_dir = f"models/{run_id}/"
    logdir = f"logs/{run_id}/"
    # To further train a saved model, point both paths at its run instead:
    # models_dir = "models/1704679280/"
    # logdir = "logs/1704679280/"

    # exist_ok=True avoids the check-then-create race of
    # `if not os.path.exists(...)` and is a no-op when resuming a run.
    os.makedirs(models_dir, exist_ok=True)
    os.makedirs(logdir, exist_ok=True)

    board_size = 20
    # Transitions gathered per rollout across all environments.
    total_obs = 10000 * 5 * 2
    # Steps each environment contributes to one rollout.
    num_steps = board_size
    # Chosen so that num_steps * num_envs == total_obs (exact for these values).
    num_envs = total_obs // num_steps
    # Each rollout is split into 8 minibatches per epoch.
    min_batch_size = total_obs // 8

    env = make_vec_env(
        SnakeEnv,
        n_envs=num_envs,
        vec_env_cls=DummyVecEnv,
        env_kwargs={"render": False, "board_size": board_size},
    )
    print(total_obs, num_steps, num_envs, min_batch_size)
    env.reset()

    # Separate four-layer, 512-unit networks for the policy (pi) and the
    # value function (vf).
    hidden_layers = [512, 512, 512, 512]
    model = PPO(
        "MlpPolicy",
        env,
        verbose=1,
        tensorboard_log=logdir,
        device="cuda",
        n_steps=num_steps,
        batch_size=min_batch_size,
        learning_rate=3e-4,
        n_epochs=10,
        policy_kwargs={
            "net_arch": dict(pi=list(hidden_layers), vf=list(hidden_layers)),
        },
    )

    TIMESTEPS = 5000000
    iters = 0
    # To continue training from a checkpoint, set `iters` to the round number
    # of that checkpoint (files are named TIMESTEPS * iters) and load it:
    # model = PPO.load(
    #     os.path.join(models_dir, str(TIMESTEPS * iters)),
    #     env=env, verbose=1, tensorboard_log=logdir, device="cuda",
    # )
    while True:
        iters += 1
        # reset_num_timesteps=False keeps the step counter (and therefore the
        # TensorBoard curve) continuous across successive learn() calls.
        model.learn(
            total_timesteps=TIMESTEPS,
            reset_num_timesteps=False,
            tb_log_name="PPO",
        )
        # os.path.join avoids the doubled separator that f"{models_dir}/..."
        # produced, since models_dir already ends with "/".
        model.save(os.path.join(models_dir, str(TIMESTEPS * iters)))
